{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "import mmcv\n",
    "from itertools import product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(7)\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Network, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)\n",
    "        \n",
    "        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)\n",
    "        self.fc2 = nn.Linear(in_features=120, out_features=60)\n",
    "        self.out = nn.Linear(in_features=60, out_features=10)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, kernel_size=2, stride=2)\n",
    "\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, kernel_size=2, stride=2)\n",
    "\n",
    "        x = torch.flatten(x, start_dim=1)\n",
    "        x = self.fc1(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.out(x)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))\n",
    "val_set = datasets.FashionMNIST(root='./data', train=False,download=True, transform=transforms.Compose([transforms.ToTensor()]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[512, 1024, 8192], [0.01, 0.001, 0.0001, 1e-05], [True, False]]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# enable tensorboard\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "parameters = dict(\n",
    "    batch_size_list = [512, 1024, 1024*8],\n",
    "    lr_list = [.01, .001, .0001, .00001],\n",
    "    shuffle = [True, False]\n",
    ")\n",
    "param_values = [v for v in parameters.values()]\n",
    "param_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 58s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 59s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 59s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 59s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 59s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 59s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 60s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 59s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 56s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 56s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 55s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 55s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 55s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 55s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 55s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 55s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 53s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 53s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 52s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 53s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 53s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 53s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 53s, ETA:     0s\n",
      "[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 10/10, 0.2 task/s, elapsed: 53s, ETA:     0s\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "\n",
    "for batch_size, lr, shuffle in product(*param_values):\n",
    "    model = Network().to(device)\n",
    "    \n",
    "    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)\n",
    "    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "    \n",
    "    comment = f'_batch_size={batch_size}_lr={lr}_shuffle={shuffle}'\n",
    "    writer = SummaryWriter(comment=comment)\n",
    "    \n",
    "    for epoch in mmcv.track_iter_progress(range(epochs)):\n",
    "        correct_train, loss_train = 0., 0.\n",
    "        for images, labels in (train_loader):\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "            preds = model(images)\n",
    "            loss = F.cross_entropy(preds, labels)\n",
    "            loss_train += loss.item()\n",
    "            correct_train += (preds.argmax(dim=1) == labels).sum()\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        correct_val, loss_val = 0., 0.\n",
    "        with torch.no_grad():\n",
    "            for images, labels in (val_loader):\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                preds = model(images)\n",
    "                loss = F.cross_entropy(preds, labels)\n",
    "                loss_val += loss.item()\n",
    "                correct_val += (preds.argmax(dim=1) == labels).sum()\n",
    "\n",
    "        acc_train = correct_train/len(train_set)\n",
    "        acc_val = correct_val/len(val_set)\n",
    "\n",
    "        writer.add_scalar('Loss/train', loss_train, epoch)\n",
    "        writer.add_scalar('Loss/test', loss_val, epoch)\n",
    "        writer.add_scalar('Accuracy/train', acc_train, epoch)\n",
    "        writer.add_scalar('Accuracy/test', acc_val, epoch)\n",
    "        \n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_preds = torch.tensor([], dtype=torch.long).to(device)\n",
    "val_labels = torch.tensor([], dtype=torch.long).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for images, labels in (val_loader):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "        preds = model(images).argmax(dim=1)\n",
    "        val_preds = torch.cat((val_preds, preds.type(torch.long)), dim=0)\n",
    "        val_labels = torch.cat((val_labels, labels.type(torch.long)), dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_preds = val_preds.cpu()\n",
    "val_labels = val_labels.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  0, 736,   0,   0,   0, 264,   0,   0,   0,   0],\n",
      "        [  0, 747,   0,   0,   0, 253,   0,   0,   0,   0],\n",
      "        [  0, 633,   0,   0,   0, 367,   0,   0,   0,   0],\n",
      "        [  0, 780,   0,   0,   0, 220,   0,   0,   0,   0],\n",
      "        [  0, 836,   0,   0,   0, 164,   0,   0,   0,   0],\n",
      "        [  0,  48,   0,   0,   0, 952,   0,   0,   0,   0],\n",
      "        [  0, 613,   0,   0,   0, 387,   0,   0,   0,   0],\n",
      "        [  0, 269,   0,   0,   0, 731,   0,   0,   0,   0],\n",
      "        [  0, 545,   0,   0,   0, 455,   0,   0,   0,   0],\n",
      "        [  0,  98,   0,   0,   0, 902,   0,   0,   0,   0]])\n"
     ]
    }
   ],
   "source": [
    "def confusion_matrix(preds, labels):\n",
    "    stacked = torch.stack((val_labels, val_preds), dim=1)\n",
    "\n",
    "    cmt = torch.zeros(10, 10, dtype=torch.int64)\n",
    "    for p in stacked:\n",
    "        j, k = p.tolist()\n",
    "        cmt[j, k] += 1\n",
    "    return cmt\n",
    "\n",
    "cmt = confusion_matrix(val_preds, val_labels)\n",
    "print(cmt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from plot_confusion_matrix import plot_confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "names = ('T-shirt/top' ,'Trouser' ,'Pullover' ,'Dress' ,'Coat' ,'Sandal' ,'Shirt' ,'Sneaker' ,'Bag' ,'Ankle boot')\n",
    "\n",
    "plot_confusion_matrix2(cmt, names, normalize=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "open-mmlab",
   "language": "python",
   "name": "open-mmlab"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
